from .protonet import ProtoNet
from .protonet_metric import ProtoNetAcc
from .protonet_loss import ProtoNetLoss
from lib.sampler import EpisodeSampler
from lib.algorithms import ALGORITHMS


@ALGORITHMS.register('protonet')
def protonet(cfg):
    """Assemble the components for the 'protonet' algorithm.

    Args:
        cfg: configuration object, passed unchanged to the ProtoNet network,
            loss and accuracy-metric constructors.

    Returns:
        dict with keys:
            'net', 'loss', 'metric' -- instances constructed from ``cfg``;
            'train_sampler', 'validate_sampler', 'test_sampler' -- all the
                ``EpisodeSampler`` class itself (not an instance).
    """
    # Built in the same order as before: network, then loss, then metric.
    network = ProtoNet(cfg)
    criterion = ProtoNetLoss(cfg)
    accuracy = ProtoNetAcc(cfg)

    # Every split uses the same sampler class. It is returned uninstantiated,
    # presumably so the caller can build it per split -- confirm with caller.
    sampler_cls = EpisodeSampler

    return dict(
        net=network,
        loss=criterion,
        metric=accuracy,
        train_sampler=sampler_cls,
        validate_sampler=sampler_cls,
        test_sampler=sampler_cls,
    )
